{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fc7b4f18-60aa-4c27-a25c-79c69cf7b89f",
   "metadata": {},
   "source": [
    "# Loading and running a trained model with PyTorch\n",
    "\n",
    "This notebook walks you through how to load and run a trained model using PyTorch. You can run each cell one by one as you work through the content, or run all cells and read through all output at the same time."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30a66123",
   "metadata": {},
   "source": [
    "In the previous notebook, `torch-use-gpu.ipynb`, we performed the following tasks:\n",
    "\n",
    "1. Imported Torch libraries\n",
    "1. Listed available GPUs.\n",
    "1. Checked that GPUs are enabled.\n",
    "1. Assigned a GPU device and retrieved the device name.\n",
    "1. Loaded vectors, matrices, and data onto a GPU.\n",
    "1. Loaded a neural network model onto a GPU.\n",
    "1. Trained the neural network model.\n",
    "\n",
    "Make sure that you have run the `torch-use-gpu.ipynb` notebook file before working through this notebook."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "075822ba",
   "metadata": {},
   "source": [
    "In this notebook, you will:\n",
    "\n",
    "1. Import libraries and GPUs into your notebook environment.\n",
    "1. Load a model and training data onto a GPU.\n",
    "1. Run a model on a GPU.\n",
    "1. Review model results."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d7ea910b",
   "metadata": {},
   "source": [
    "## Run the model with your trained weights\n",
    "\n",
    "The first thing to to is set up the neural network model to run using the weights you created when you trained the model in the `torch-use-gpu.ipynb` file."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69756c91",
   "metadata": {},
   "source": [
    "### Set up your environment\n",
    "\n",
    "Import the torch and torchvision libraries you need in order to work with PyTorch, and ensure that your GPU resources are visible in your notebook server."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecb20299-9bb5-4851-b7b4-e32691d4c69b",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install torchvision==0.9.1\n",
    "!pip install tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b377805e-2ec1-4905-9f73-12bbc9750e76",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torchvision import datasets\n",
    "import torchvision.transforms as transforms\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aacd0b93-5c6e-4864-8f6e-85bc10d5a132",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)  # let's see what device we got"
   ]
  },
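  {
   "cell_type": "markdown",
   "id": "b8d2f4a6",
   "metadata": {},
   "source": [
    "To confirm which GPU resources are visible, you can also list every CUDA device by name. The following cell is a minimal sketch that uses only `torch.cuda.device_count()` and `torch.cuda.get_device_name()`; the loop variable `index` is illustrative. If no GPU is visible, the cell reports that the notebook is running on the CPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9e3a5b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: list each visible CUDA device, or report the CPU fallback.\n",
    "if torch.cuda.is_available():\n",
    "    for index in range(torch.cuda.device_count()):\n",
    "        print(index, torch.cuda.get_device_name(index))\n",
    "else:\n",
    "    print(\"No CUDA device is visible; running on the CPU.\")"
   ]
  },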
  {
   "cell_type": "markdown",
   "id": "127710c2",
   "metadata": {},
   "source": [
    "### Load the model and trained weights onto the GPU\n",
    "\n",
    "Run the following commands to set up a simple neural network model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ddda446-80c9-4103-b5f8-b424a3347742",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Getting set to import our trained model.\n",
    "\n",
    "# batch size of 1 so we can look at one image at time.\n",
    "batch_size = 1\n",
    "\n",
    "\n",
    "class SimpleNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(SimpleNet, self).__init__()\n",
    "        self.fc1 = nn.Linear(784, 784)\n",
    "        self.fc2 = nn.Linear(784, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.view(batch_size, -1)\n",
    "        x = self.fc1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.fc2(x)\n",
    "        output = F.softmax(x, dim=1)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ced04b88",
   "metadata": {},
   "source": [
    "Then, load the model onto the GPU along with the trained weights from the `torch-use-gpu.ipynb` workbook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4c24aa3-0b70-4290-8e70-c1ef30775f76",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = SimpleNet().to( device )\n",
    "model.load_state_dict( torch.load(\"mnist_fashion_SimpleNet.pt\") )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de877e72-f731-42c1-9fb6-2073af0c1fed",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33a8c1ad-54f8-4e3b-a2c3-33ac10b57396",
   "metadata": {},
   "outputs": [],
   "source": [
    "# using our loader as before - will have that same batch size of 1.\n",
    "\n",
    "test_transform = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.1307,), (0.3081,)),\n",
    "        ])\n",
    "\n",
    "\n",
    "testset = datasets.FashionMNIST('./data', train=False,\n",
    "                                transform=test_transform)\n",
    "test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,\n",
    "                                          shuffle=False, num_workers=2)\n",
    "\n",
    "testing_iterator = iter(test_loader)  # create an iterator of our loader\n",
    "\n",
    "# A dictionary to map our class numbers to their items.\n",
    "labels_map = {\n",
    "    0: \"T-Shirt\",\n",
    "    1: \"Trouser\",\n",
    "    2: \"Pullover\",\n",
    "    3: \"Dress\",\n",
    "    4: \"Coat\",\n",
    "    5: \"Sandal\",\n",
    "    6: \"Shirt\",\n",
    "    7: \"Sneaker\",\n",
    "    8: \"Bag\",\n",
    "    9: \"Ankle Boot\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f3e8ea3-b322-485c-a13d-c31e36d542e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_probs(model, x, y, device):\n",
    "    \"\"\" Function to return the current probabilities of the classes given the model \"\"\"\n",
    "    model = model.to(device)\n",
    "    x = x.to(device)\n",
    "    y = y.to(device)\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        output = model(x)  # model classifies the input\n",
    "        output = output.squeeze()\n",
    "        \n",
    "    # Return the list of probabilites and the probability of ground truth class\n",
    "    return output, output[y]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b842bd7-66a1-4feb-a0ee-45b3c4d09a82",
   "metadata": {},
   "source": [
    "### Run the model and check the results\n",
    "\n",
    "With the model and trained weights imported, we can now run our model and see how well it performs.\n",
    "\n",
    "Run the below cell repeatedly to see if the model predicts the selected image correctly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d4b137b-5901-4d56-81b3-762257a23857",
   "metadata": {},
   "outputs": [],
   "source": [
    "image, label = testing_iterator.next()  # get one datapoint\n",
    "image = image.squeeze(0)  # 'squeeze' out the tensor from the outer list.\n",
    "print(\"image shape is\", image.shape)\n",
    "\n",
    "probs, gt_prob = get_probs(model, image, label, device)  # run model prediction\n",
    "\n",
    "\n",
    "# plot our image\n",
    "plt.imshow(image.view(28,28), cmap=\"gray\") \n",
    "plt.show()\n",
    "\n",
    "# what did the model predict?\n",
    "max_label = torch.argmax(probs)\n",
    "print(\"Model predicted:\", labels_map[max_label.item()], \"with confidence % of\", probs[max_label.item()].item())\n",
    "print(\"Correct label is:\", labels_map[label.item()])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8500d806",
   "metadata": {},
   "source": [
    "The model predicts the category that the photographed object should be in, and lists its confidence percentage beside the prediction."
   ]
  }
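  ,
  {
   "cell_type": "markdown",
   "id": "d4f6a8c0",
   "metadata": {},
   "source": [
    "Checking one image at a time gives only a rough impression of how well the model performs. As a minimal sketch, the following cell estimates accuracy over the first 1,000 test images. It assumes that `model`, `test_loader`, and `device` are defined as in the cells above; the names `num_samples` and `correct` are illustrative and do not come from the previous notebook. The loop creates its own iterator over `test_loader`, so it does not advance `testing_iterator`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5a7b9d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: estimate accuracy on the first num_samples test images.\n",
    "# num_samples and correct are illustrative names, not from the previous notebook.\n",
    "num_samples = 1000\n",
    "correct = 0\n",
    "\n",
    "with torch.no_grad():\n",
    "    for i, (image, label) in enumerate(test_loader):\n",
    "        if i >= num_samples:\n",
    "            break\n",
    "        probs = model(image.to(device)).squeeze()  # class probabilities for one image\n",
    "        correct += int(torch.argmax(probs).item() == label.item())\n",
    "\n",
    "print(f\"Accuracy on the first {num_samples} test images: {correct / num_samples:.1%}\")"
   ]
  }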
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "e7370f93d1d0cde622a1f8e1c04877d8463912d04d973331ad4851f04de6915a"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
